{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import logging\n",
    "import numpy as np\n",
    "from optparse import OptionParser\n",
    "import sys\n",
    "from time import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.feature_extraction.text import HashingVectorizer\n",
    "from sklearn.feature_selection import SelectFromModel\n",
    "from sklearn.feature_selection import SelectKBest, chi2\n",
    "from sklearn.linear_model import RidgeClassifier\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.svm import LinearSVC\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.linear_model import Perceptron\n",
    "from sklearn.linear_model import PassiveAggressiveClassifier\n",
    "from sklearn.naive_bayes import BernoulliNB, MultinomialNB\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.neighbors import NearestCentroid\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.utils.extmath import density\n",
    "from sklearn import metrics\n",
    "\n",
    "# Emit progress logs at INFO level (basicConfig's default handler writes to stderr)\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',\n",
    "                    level=logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Option at 0x1fa45a66d30: --filtered>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Command-line options controlling the benchmark run.\n",
    "# Every option names its `dest` explicitly and uses optparse's string type\n",
    "# names, so all declarations follow one convention.\n",
    "op = OptionParser()\n",
    "op.add_option(\"--report\",\n",
    "              action=\"store_true\", dest=\"print_report\",\n",
    "              help=\"Print a detailed classification report.\")\n",
    "op.add_option(\"--chi2_select\",\n",
    "              action=\"store\", type=\"int\", dest=\"select_chi2\",\n",
    "              help=\"Select some number of features using a chi-squared test\")\n",
    "op.add_option(\"--confusion_matrix\",\n",
    "              action=\"store_true\", dest=\"print_cm\",\n",
    "              help=\"Print the confusion matrix.\")\n",
    "op.add_option(\"--top10\",\n",
    "              action=\"store_true\", dest=\"print_top10\",\n",
    "              help=\"Print ten most discriminative terms per class\"\n",
    "                   \" for every classifier.\")\n",
    "op.add_option(\"--all_categories\",\n",
    "              action=\"store_true\", dest=\"all_categories\",\n",
    "              help=\"Whether to use all categories or not.\")\n",
    "op.add_option(\"--use_hashing\",\n",
    "              action=\"store_true\", dest=\"use_hashing\",\n",
    "              help=\"Use a hashing vectorizer.\")\n",
    "op.add_option(\"--n_features\",\n",
    "              action=\"store\", type=\"int\", dest=\"n_features\",\n",
    "              default=2 ** 16,\n",
    "              help=\"n_features when using the hashing vectorizer.\")\n",
    "op.add_option(\"--filtered\",\n",
    "              action=\"store_true\", dest=\"filtered\",\n",
    "              help=\"Remove newsgroup information that is easily overfit: \"\n",
    "                   \"headers, signatures, and quoting.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Automatically created module for IPython interactive environment\n",
      "Usage: __main__.py [options]\n",
      "\n",
      "Options:\n",
      "  -h, --help            show this help message and exit\n",
      "  --report              Print a detailed classification report.\n",
      "  --chi2_select=SELECT_CHI2\n",
      "                        Select some number of features using a chi-squared\n",
      "                        test\n",
      "  --confusion_matrix    Print the confusion matrix.\n",
      "  --top10               Print ten most discriminative terms per class for\n",
      "                        every classifier.\n",
      "  --all_categories      Whether to use all categories or not.\n",
      "  --use_hashing         Use a hashing vectorizer.\n",
      "  --n_features=N_FEATURES\n",
      "                        n_features when using the hashing vectorizer.\n",
      "  --filtered            Remove newsgroup information that is easily overfit:\n",
      "                        headers, signatures, and quoting.\n"
     ]
    }
   ],
   "source": [
    "def is_interactive():\n",
    "    \"\"\"Return True when the ``__main__`` module has no ``__file__`` attribute.\"\"\"\n",
    "    main_module = sys.modules['__main__']\n",
    "    return not hasattr(main_module, '__file__')\n",
    "\n",
    "# work-around for Jupyter notebook and IPython console: parse an empty\n",
    "# argument list there instead of sys.argv\n",
    "argv = [] if is_interactive() else sys.argv[1:]\n",
    "(opts, args) = op.parse_args(argv)\n",
    "if args:\n",
    "    # op.error() prints the usage and the message to stderr and exits with\n",
    "    # status 2, so the former sys.exit(1) after it could never run.\n",
    "    op.error(\"this script takes no arguments.\")\n",
    "\n",
    "print(__doc__)\n",
    "op.print_help()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Newsgroups to load: None when --all_categories is given, otherwise a\n",
    "# fixed four-topic subset\n",
    "categories = None if opts.all_categories else [\n",
    "    'alt.atheism',\n",
    "    'talk.religion.misc',\n",
    "    'comp.graphics',\n",
    "    'sci.space',\n",
    "]\n",
    "\n",
    "# Parts of each post to strip when --filtered is given\n",
    "remove = ('headers', 'footers', 'quotes') if opts.filtered else ()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 20 newsgroups dataset for categories:\n",
      "['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading 20 newsgroups dataset for categories:\")\n",
    "# a falsy `categories` (None) is reported as \"all\"\n",
    "print(categories or \"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2018-02-22 21:19:45,766 INFO Downloading 20news dataset. This may take a few minutes.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data loaded\n"
     ]
    }
   ],
   "source": [
    "# Fetch the train and test subsets with identical settings (train first)\n",
    "data_train, data_test = (\n",
    "    fetch_20newsgroups(subset=subset, categories=categories,\n",
    "                       shuffle=True, random_state=42,\n",
    "                       remove=remove)\n",
    "    for subset in ('train', 'test')\n",
    ")\n",
    "print('data loaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\"From: rych@festival.ed.ac.uk (R Hawkes)\\nSubject: 3DS: Where did all the texture rules go?\\nLines: 21\\n\\nHi,\\n\\nI've noticed that if you only save a model (with all your mapping planes\\npositioned carefully) to a .3DS file that when you reload it after restarting\\n3DS, they are given a default position and orientation.  But if you save\\nto a .PRJ file their positions/orientation are preserved.  Does anyone\\nknow why this information is not stored in the .3DS file?  Nothing is\\nexplicitly said in the manual about saving texture rules in the .PRJ file. \\nI'd like to be able to read the texture rule information, does anyone have \\nthe format for the .PRJ file?\\n\\nIs the .CEL file format available from somewhere?\\n\\nRych\\n\\n======================================================================\\nRycharde Hawkes\\t\\t\\t\\temail: rych@festival.ed.ac.uk\\nVirtual Environment Laboratory\\nDept. of Psychology\\t\\t\\tTel  : +44 31 650 3426\\nUniv. of Edinburgh\\t\\t\\tFax  : +44 31 667 0150\\n======================================================================\\n\", \"Subject: Re: Biblical Backing of Koresh's 3-02 Tape (Cites enclosed)\\nFrom: kmcvay@oneb.almanac.bc.ca (Ken Mcvay)\\nOrganization: The Old Frog's Almanac\\nLines: 20\\n\\nIn article <20APR199301460499@utarlg.uta.edu> b645zaw@utarlg.uta.edu (stephen) writes:\\n\\n>Seems to me Koresh is yet another messenger that got killed\\n>for the message he carried. (Which says nothing about the \\n\\nSeems to be, barring evidence to the contrary, that Koresh was simply\\nanother deranged fanatic who thought it neccessary to take a whole bunch of\\nfolks with him, children and all, to satisfy his delusional mania. 
Jim\\nJones, circa 1993.\\n\\n>In the mean time, we sure learned a lot about evil and corruption.\\n>Are you surprised things have gotten that rotten?\\n\\nNope - fruitcakes like Koresh have been demonstrating such evil corruption\\nfor centuries.\\n-- \\nThe Old Frog's Almanac - A Salute to That Old Frog Hisse'f, Ryugen Fisher \\n     (604) 245-3205 (v32) (604) 245-4366 (2400x4) SCO XENIX 2.3.2 GT \\n  Ladysmith, British Columbia, CANADA. Serving Central Vancouver Island  \\nwith public access UseNet and Internet Mail - home to the Holocaust Almanac\\n\"]\n",
      "2034 documents - 3.980MB (training set)\n",
      "1353 documents - 2.867MB (test set)\n",
      "4 categories\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# order of labels in `target_names` can be different from `categories`\n",
    "target_names = data_train.target_names\n",
    "\n",
    "def size_mb(docs):\n",
    "    \"\"\"Return the combined UTF-8 encoded size of ``docs`` in megabytes (1 MB = 1e6 bytes).\"\"\"\n",
    "    total_bytes = 0\n",
    "    for text in docs:\n",
    "        total_bytes += len(text.encode('utf-8'))\n",
    "    return total_bytes / 1e6\n",
    "\n",
    "# Peek at the first two raw training documents (data_train.data is sliced like a list)\n",
    "print(data_train.data[0:2])\n",
    "data_train_size_mb = size_mb(data_train.data)\n",
    "data_test_size_mb = size_mb(data_test.data)\n",
    "\n",
    "print(\"%d documents - %0.3fMB (training set)\" % (\n",
    "    len(data_train.data), data_train_size_mb))\n",
    "print(\"%d documents - %0.3fMB (test set)\" % (\n",
    "    len(data_test.data), data_test_size_mb))\n",
    "# Count the labels actually loaded: `categories` is None when\n",
    "# --all_categories is given, so len(categories) would raise TypeError.\n",
    "print(\"%d categories\" % len(target_names))\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting features from the training data using a sparse vectorizer\n"
     ]
    }
   ],
   "source": [
    "# Class labels (targets) of the training and test splits\n",
    "y_train = data_train.target\n",
    "y_test = data_test.target\n",
    "print(\"Extracting features from the training data using a sparse vectorizer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using TfidfVectorizer...\n",
      "done in 1.610643s at 2.471MB/s\n",
      "n_samples: 2034, n_features: 33809\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\administrator\\appdata\\local\\programs\\python\\python35\\lib\\site-packages\\sklearn\\feature_extraction\\text.py:1059: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  if hasattr(X, 'dtype') and np.issubdtype(X.dtype, np.float):\n"
     ]
    }
   ],
   "source": [
    "t0 = time()\n",
    "if not opts.use_hashing:\n",
    "    # Default path: the TF-IDF vectorizer is fitted on the training documents\n",
    "    print('using TfidfVectorizer...')\n",
    "    vectorizer = TfidfVectorizer(stop_words='english', max_df=0.5,\n",
    "                                 sublinear_tf=True)\n",
    "    X_train = vectorizer.fit_transform(data_train.data)\n",
    "else:\n",
    "    # --use_hashing: the hashing vectorizer is applied with transform() only\n",
    "    print('use hashing...')\n",
    "    vectorizer = HashingVectorizer(n_features=opts.n_features,\n",
    "                                   stop_words='english', alternate_sign=False)\n",
    "    X_train = vectorizer.transform(data_train.data)\n",
    "duration = time() - t0\n",
    "print(\"done in %fs at %0.3fMB/s\" % (duration, data_train_size_mb / duration))\n",
    "print(\"n_samples: %d, n_features: %d\" % X_train.shape)\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting features from the test data using the same vectorizer\n",
      "done in 1.111289s at 2.580MB/s\n",
      "n_samples: 1353, n_features: 33809\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\administrator\\appdata\\local\\programs\\python\\python35\\lib\\site-packages\\sklearn\\feature_extraction\\text.py:1059: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  if hasattr(X, 'dtype') and np.issubdtype(X.dtype, np.float):\n"
     ]
    }
   ],
   "source": [
    "# Vectorize the test documents with the vectorizer built for the training data\n",
    "# and report the elapsed time and throughput.\n",
    "print(\"Extracting features from the test data using the same vectorizer\")\n",
    "t0 = time()\n",
    "# Only transform() is called here: the vectorizer is not re-fitted on the test documents\n",
    "X_test = vectorizer.transform(data_test.data)\n",
    "duration = time() - t0\n",
    "print(\"done in %fs at %0.3fMB/s\" % (duration, data_test_size_mb / duration))\n",
    "print(\"n_samples: %d, n_features: %d\" % X_test.shape)\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['00', '000', '0000', '00000', '000000', '000005102000', '000021', '000062david42', '0000vec', '0001']\n"
     ]
    }
   ],
   "source": [
    "# mapping from integer feature index to original token string\n",
    "if opts.use_hashing:\n",
    "    # no feature names are collected on the hashing path\n",
    "    feature_names = None\n",
    "elif hasattr(vectorizer, 'get_feature_names_out'):\n",
    "    # newer scikit-learn: get_feature_names() is gone; keep a plain list so\n",
    "    # later truthiness checks on feature_names behave as before\n",
    "    feature_names = vectorizer.get_feature_names_out().tolist()\n",
    "else:\n",
    "    feature_names = vectorizer.get_feature_names()\n",
    "\n",
    "# Slicing None would raise TypeError when --use_hashing is given\n",
    "if feature_names is not None:\n",
    "    print(feature_names[0:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Feature selection: keep only the k best features according to a chi-squared test\n",
    "if opts.select_chi2:\n",
    "    print(\"Extracting %d best features by a chi-squared test\" %\n",
    "          opts.select_chi2)\n",
    "    t0 = time()\n",
    "    ch2 = SelectKBest(chi2, k=opts.select_chi2)\n",
    "    X_train = ch2.fit_transform(X_train, y_train)\n",
    "    X_test = ch2.transform(X_test)\n",
    "    # `is not None` instead of truthiness: feature_names may already be a\n",
    "    # NumPy array (if a later cell ran first), whose truth value is ambiguous\n",
    "    if feature_names is not None:\n",
    "        # keep selected feature names\n",
    "        feature_names = [feature_names[i] for i\n",
    "                         in ch2.get_support(indices=True)]\n",
    "    print(\"done in %fs\" % (time() - t0))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['00' '000' '0000' '00000' '000000' '000005102000' '000021'\n",
      " '000062david42' '0000vec' '0001']\n"
     ]
    }
   ],
   "source": [
    "# Convert to a NumPy array so the names can be indexed with an index array.\n",
    "# `is not None` keeps this cell re-runnable (truthiness of an ndarray raises\n",
    "# ValueError) and skips the preview when feature_names is None.\n",
    "if feature_names is not None:\n",
    "    feature_names = np.asarray(feature_names)\n",
    "    print(feature_names[0:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Shorten text that is too long to display\n",
    "def trim(s):\n",
    "    \"\"\"Return ``s`` unchanged if it fits in 80 columns, else its first 77 characters plus \"...\".\"\"\"\n",
    "    if len(s) > 80:\n",
    "        return s[:77] + \"...\"\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def benchmark(clf):\n",
    "    \"\"\"Fit ``clf`` on the training split, score it on the test split and print a report.\n",
    "\n",
    "    Reads the notebook-level ``X_train``, ``y_train``, ``X_test``, ``y_test``,\n",
    "    ``feature_names``, ``target_names`` and the parsed ``opts``.\n",
    "\n",
    "    Returns a tuple ``(clf_descr, score, train_time, test_time)``: the\n",
    "    classifier's class name, its accuracy on the test split, and the fit and\n",
    "    predict durations in seconds as measured with ``time()``.\n",
    "    \"\"\"\n",
    "    print('_' * 80)\n",
    "    print(\"Training: \")\n",
    "    print(clf)\n",
    "    t0 = time()\n",
    "    clf.fit(X_train, y_train)\n",
    "    train_time = time() - t0\n",
    "    print(\"train time: %0.3fs\" % train_time)\n",
    "\n",
    "    t0 = time()\n",
    "    pred = clf.predict(X_test)\n",
    "    test_time = time() - t0\n",
    "    print(\"test time:  %0.3fs\" % test_time)\n",
    "\n",
    "    score = metrics.accuracy_score(y_test, pred)\n",
    "    print(\"accuracy:   %0.3f\" % score)\n",
    "\n",
    "    # Only estimators exposing coef_ get the coefficient statistics\n",
    "    if hasattr(clf, 'coef_'):\n",
    "        print(\"dimensionality: %d\" % clf.coef_.shape[1])\n",
    "        print(\"density: %f\" % density(clf.coef_))\n",
    "\n",
    "        # Use the 10 highest-weighted features of each class as its keywords\n",
    "        if opts.print_top10 and feature_names is not None:\n",
    "            print(\"top 10 keywords per class:\")\n",
    "            for i, label in enumerate(target_names):\n",
    "                top10 = np.argsort(clf.coef_[i])[-10:]\n",
    "                print(trim(\"%s: %s\" % (label, \" \".join(feature_names[top10]))))\n",
    "        print()\n",
    "\n",
    "    if opts.print_report:\n",
    "        print(\"classification report:\")\n",
    "        print(metrics.classification_report(y_test, pred,\n",
    "                                            target_names=target_names))\n",
    "\n",
    "    # --confusion_matrix was parsed into opts.print_cm but never acted on before\n",
    "    if opts.print_cm:\n",
    "        print(\"confusion matrix:\")\n",
    "        print(metrics.confusion_matrix(y_test, pred))\n",
    "\n",
    "    # Hand the measurements back so callers can collect and compare runs\n",
    "    clf_descr = type(clf).__name__\n",
    "    return clf_descr, score, train_time, test_time"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
